#! /usr/bin/env python
# Copyright (c) 2017, Nefeli Networks, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the names of the copyright holders nor the names of their
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from __future__ import print_function

"""
Pseudo-module multi-import.

This code is intended to allow you to write, e.g.:

    foo_names = ['foo.' + path[:-3]
        for path in os.listdir('foo') if path.endswith('.py')]
    module = pm_import('foo', foo_names)

to import the entire contents of directory "foo" as if it
were all just one big file.
"""

import importlib
import sys
import types

# Names never copied out of a submodule that lacks __all__ (see
# pm_import).  Every module carries its own copy of these, so copying
# them would make every submodule "collide" with every other one.
#
# The first part is whatever dunders this module itself has at import
# time.  That alone is not enough: some standard module attributes exist
# only on some modules (e.g. __path__ only on packages, __annotations__ /
# __annotate__ only on modules with annotated globals or once they have
# been accessed), so they are listed explicitly as well.
__std_skip = set(n for n in sys.modules[__name__].__dict__.keys()
                 if n.startswith('__')) | set([
    '__annotate__', '__annotations__', '__builtins__', '__cached__',
    '__doc__', '__file__', '__loader__', '__name__', '__package__',
    '__path__', '__spec__',
])


class Collisions(Exception):
    """
    Raised by pm_import when a name is exported by more than one submodule.

    The keyword argument ``collisions`` is required and is stored on the
    instance as ``.collisions``; pm_import passes a dict mapping each
    clashing name to the list of submodule names that exported it.  All
    remaining positional and keyword arguments are forwarded to the base
    Exception unchanged.  Omitting ``collisions`` raises KeyError.
    """

    def __init__(self, *args, **kwargs):
        # Remove our own keyword before delegating, so it is not handed
        # to Exception, which does not take keyword arguments.
        clash_map = kwargs.pop('collisions')
        self.collisions = clash_map
        super(Collisions, self).__init__(*args, **kwargs)


def pm_import(mname, iterator, name_filter=None, package=None, override=False):
    """
    Import every module named by the iterator, copying some of
    its names into a single pseudo-module named mname.

    Note: for this to work on a directory, the directory will
    typically have to contain an __init__.py.  If that is not
    empty you probably should be skipping it.

    You get something much like the effect of running:

    if mname in sys.modules:
        return sys.modules[mname]
    module = types.ModuleType(mname)
    for submodule in iterator:
        tmp = importlib.import_module(submodule, package)
        names = filter(name_filter, tmp.__all__)
        module.__dict__.update((k, getattr(tmp, k)) for k in names)
    sys.modules[mname] = module
    return module

    except that a submodule without __all__ contributes every name in
    its namespace other than the standard module attributes (__name__,
    __file__, and so on).  Unlike "from m import *", names starting
    with a single underscore are NOT skipped in that case.  Values are
    fetched with getattr, so __all__ may name attributes supplied by a
    module-level __getattr__ (PEP 562).

    If override is True, this really is nearly all this does: when
    two submodules export the same name, the later one wins.  If
    override is False (the default), this also makes sure the names
    are all unique, and if not, raises a Collisions error whose
    .collisions attribute maps each clashing name to the list of
    submodule names that exported it.  In that mode a submodule that
    the iterator yields more than once is merged only once, so it
    does not collide with itself.

    mname is registered in sys.modules only on success; the
    submodules themselves remain imported either way.

    Note: if name_filter is None, all names pass through it.
    A typical actual filter might be:

        name_filter=lambda name: name.startswith('take_me_')

    which would combine all names starting with 'take_me_', such as
    'take_me_home', 'take_me_later', and so on.

    Args:
        mname: key for the pseudo-module in sys.modules.  If that key
            is already present, its module is returned and the
            iterator is not consumed.
        iterator: iterable of module names to import.
        name_filter: optional predicate choosing which names to copy.
        package: anchor for relative module names, passed through to
            importlib.import_module.
        override: if True, let later submodules silently replace names
            from earlier ones instead of raising Collisions.

    Returns:
        The new (or previously registered) pseudo-module.

    Raises:
        Collisions: override is False and some name is exported by
            more than one distinct submodule.
    """
    module = sys.modules.get(mname, None)
    if module is not None:
        return module
    allnames = set()
    collisions = set()
    sources = {}
    merged = set()  # submodule objects already merged (override=False only)
    module = types.ModuleType(mname)
    for submodule in iterator:
        tmp = importlib.import_module(submodule, package)
        try:
            names = tmp.__dict__['__all__']
        except KeyError:
            # standard names like __name__ will always conflict,
            # so always skip them.
            names = set(tmp.__dict__.keys()) - __std_skip
        names = filter(name_filter, names)

        if not override:
            if tmp in merged:
                # The same module was yielded again (possibly spelled
                # differently, e.g. relative vs. absolute).  Its names are
                # already merged and recorded; redoing that would report
                # every one of them as colliding with itself.
                continue
            merged.add(tmp)

            # Materialize once: a filter result can only be iterated once,
            # and names is used several times below.
            names = set(names)
            collisions |= names & allnames
            allnames |= names
            for k in names:
                # remember which submodules exported this name, for the
                # error report
                sources.setdefault(k, []).append(tmp.__name__)

        # With override=True, later submodules simply overwrite earlier
        # values here.  getattr (not a __dict__ lookup) also resolves
        # names provided lazily by a module-level __getattr__.
        module.__dict__.update({k: getattr(tmp, k) for k in names})

    # collisions holds the names exported by more than one submodule
    # (always empty when override is True).  Report each one together
    # with the submodules that exported it.
    if collisions:
        collisions = {k: sources[k] for k in collisions}
        raise Collisions('non-unique names in pm_import({!r})'.format(mname),
                         collisions=collisions)

    sys.modules[mname] = module
    return module
